Sk refactoring#394
Conversation
- Introduced learner management in DoubleMLScalar with properties for learner names and instances. - Added abstract method `set_learners` to enforce learner setting in subclasses. - Updated PLR to utilize the new learner management system, including validation checks for learner instances. - Refactored tests to align with the new learner management approach, ensuring proper exception handling and validation.
…nd utility functions
- Implemented the IRM class for double machine learning with interactive regression models in irm_scalar.py. - Added core estimation tests for IRM scalar in test_irm_scalar.py. - Created exception handling tests for IRM scalar in test_irm_scalar_exceptions.py. - Developed tests for handling external predictions in test_irm_scalar_external_predictions.py. - Added return type validation tests for IRM scalar in test_irm_scalar_return_types.py. - Compared the new IRM scalar implementation against the existing DoubleMLIRM in test_irm_scalar_vs_irm.py.
…standards, error handling, performance guidelines, and testing conventions.
…d tests for cluster-based sample splitting and external prediction validation.
…rs; enhance tests for return types and reset behavior.
…, testing, and scalar model test structure
…g; update tests for consistency
- Added `_sensitivity_element_est` method to `DoubleMLScalar`, `IRM`, and `PLR` classes to compute sensitivity elements including sigma2, nu2, and their influence functions. - Introduced `sensitivity_elements` property to retrieve computed sensitivity elements after model fitting. - Implemented validation checks for sensitivity elements in `DoubleMLScalar`. - Added exception handling for sensitivity analysis methods in `IRM` and `PLR` classes to ensure proper input types and values. - Created unit tests for sensitivity analysis, including checks for element shapes, bounds, and exception handling in both `IRM` and `PLR` models. - Ensured compatibility of sensitivity elements between scalar and legacy models in comparison tests.
…ndling in DoubleMLScalar
- Implemented `cate()` and `gate()` methods in `IRM` and `PLR` classes for estimating conditional average treatment effects. - Enhanced `DoubleMLBLP` to support per-rep basis for multi-rep scenarios. - Updated tests for `IRM` and `PLR` to validate new functionality, including checks for correct handling of multi-rep bases and group effects. - Improved validation of basis inputs in `DoubleMLBLP` to accept both single DataFrame and list of DataFrames. - Added new test cases to ensure robustness of the new features and backward compatibility with legacy models.
…sion and add comprehensive tests
… and enhance error handling in PLR and LearnerSpec validation
… checks into dedicated functions
Apply ruff D200/D213/D413 auto-fixes and add __init__ docstrings to DoubleMLVector and PLRVector.
…reamline sample comparison logic in tests
…bleMLScalar class
… DoubleMLScalar class
JanTeichertKluge
left a comment
There was a problem hiding this comment.
Thanks for the comprehensive refactoring. I think the changes are beneficial for the package as a whole. The multi-level hierarchy (DoubleMLBase → DoubleMLScalar → LinearScoreMixin → PLR/IRM) is very logical and clear. The centralization of _LEARNER_SPECS / validate_learner is also a particularly successful improvement over the previously _check_learner calls.
I have noted a few minor points in the comments.
| ml_l_info = self._learners["ml_l"] | ||
| self._learners["ml_g"] = LearnerInfo( | ||
| learner=clone(ml_l_info.learner), | ||
| is_classifier=ml_l_info.is_classifier, |
There was a problem hiding this comment.
Aren't we effectively skipping the validation step here? E.g. the ml_g has allow_classifier=False, but if ml_l is a classifier (which would be valid?), the clone will inherit is_classifier=True
| def __init__( | ||
| self, | ||
| obj_dml_data: DoubleMLBaseData, | ||
| score: str = "default", |
There was a problem hiding this comment.
Not sure about this. See comment on super().__init__()
| ) | ||
|
|
||
| # Call parent constructor | ||
| super().__init__(obj_dml_data) |
There was a problem hiding this comment.
We call super after setting score = "default".
| # Call parent constructor | ||
| super().__init__(obj_dml_data) | ||
|
|
||
| self._score = score |
There was a problem hiding this comment.
and we overwrite the score here. I think the score is, for every child lcass model like plr, irm etc., inherited and set to the specific defaults.
| if has_l and not has_g: | ||
| warnings.warn("For score='IV-type', ml_g not set. Cloning ml_l to ml_g.") | ||
| # Clone the learner and register with same info | ||
| from ..utils._learner import LearnerInfo |
There was a problem hiding this comment.
This import should also be moved
Summary
Introduce a new
DoubleMLScalar/DoubleMLVectorclass hierarchy alongside the existingDoubleMLAPI. The refactor delivers a cleaner, more testable design with explicit tuning, nuisance evaluation, and sensitivity analysis as first-class features. Two concrete scalar models (PLR,IRM) and one vector model (PLRVector) are ported, each backed by a comprehensive test suite that proves exact numerical equivalence with the legacy classes.Motivation
The legacy
DoubleMLbase class conflates single-parameter estimation, multi-treatment orchestration, and inference into one large class. This makes it hard to:The new hierarchy separates these concerns via a layered design with explicit hooks.
New Class Hierarchy
Plus a parallel multi-treatment track:
See doc/diagrams/architecture.md for the full UML and method-resolution diagrams.
Key Design Decisions
__init__accepts learners as optional kwargs (e.g.ml_l,ml_m,ml_g) for one-line construction, but they can also be configured (or replaced) later viaset_learners(...). Decoupling the two paths makes it possible to swap learners, re-tune, or re-fit without rebuilding the model._learner_namesas single source of truth — drives prediction-dict initialization and learner-availability checks; subclasses just declare the list.draw_sample_splitting()is its own step and can be called independently or re-drawn.fit()— orchestratesdraw_sample_splitting()→fit_nuisance_models()→estimate_causal_parameters(). Subclasses implement_nuisance_est()and_get_score_elements(); the mixin provides_est_causal_pars_and_se().fit()/fit_nuisance_models(), validated against_learner_names, and pre-filled before the cross-fitting loop.What's Included
Core infrastructure
LinearScoreMixinScalar models
DoubleMLPLRScalarwithcate(),gate(),_partial_out()DoubleMLIRMScalarwithcate(),gate(), weighted scores (array + dict-with-weights_bar)Vector models
DoubleMLPLRVector, validated against legacyDoubleMLPLRfor multi-treatmentCross-cutting features
tune_ml_models()onDoubleMLScalarwith pruning support,_LEARNER_PARAM_ALIASES(e.g. IRMml_g→[ml_g0, ml_g1]), and a_get_tuning_data()hook for subclass-specific tuning targetsnuisance_targets,nuisance_loss, andevaluate_learners(metric=...)with auto-defaulted RMSE / log-loss and NaN-aware masking_sensitivity_element_est()hook running over all reps post-fit, with framework-ready shapes; supports the fullsensitivity_analysis()pipelineDoubleMLBLPper-rep basis —basismay be a singlepd.DataFrame(shared) orlist[pd.DataFrame]of lengthn_rep. Also fixes a multi-rep / multi-column bug in legacyDoubleMLPLR.cate()(doubleml/utils/blp.py)Test suites
Every scalar model ships with the mandatory 5-file structure plus dedicated files for tuning, evaluation, and sensitivity:
test_<model>_scalar.py— 3-sigma estimation accuracytest_<model>_scalar_return_types.py— property types/shapestest_<model>_scalar_exceptions.py— input validationtest_<model>_scalar_vs_<model>.py— exact match with legacy (rtol=1e-9)test_<model>_scalar_external_predictions.py— external-prediction equivalencetest_<model>_scalar_tune_ml_models.py— Optuna tuningtest_<model>_scalar_evaluate_learners.py— nuisance loss / metricstest_<model>_scalar_sensitivity.py— sensitivity bounds & monotonicitytest_<model>_scalar_cate_gate.py— CATE/GATE (PLR & IRM)PLR vector ships with 5 corresponding test files. Plus shared scalar-level tests in doubleml/tests/ (cluster splits, fit, set-sample-splitting, tune-pruning, tune-exceptions, ext-predictions).
Tooling & docs
Feature Parity with Legacy Classes
cate()/gate()(PLR + IRM)_partial_out()(PLR)weights_bar(IRM)_check_smpls_dependent_inputs()hookpolicy_tree()(IRM)trimming_rule/trimming_thresholddeprecated propsps_processor_configBackwards Compatibility
DoubleMLPLR,DoubleMLIRM, …) remain unchanged and pass their existing test suites.DoubleMLPLR.cate()(basis * D_tildemis-broadcast forn_rep > 1andd_basis > 1) is fixed via the new BLP per-rep API.Test Plan
pytest -m cipasses locallypytest doubleml/plm/tests/ doubleml/irm/tests/ -v— full module suites for both refactored and legacy classespytest doubleml/tests/test_scalar_*.py -v— shared scalar infrastructure testsblack .,ruff check .,mypy doublemlclean (pre-existing mypy errors not introduced by this branch)summary,confint,bootstrap, andsensitivity_analysis()on a fittedDoubleMLPLRScalarandDoubleMLIRMScalar*_scalar_vs_*.py) atrtol=1e-9Follow-ups (out of scope)
DoubleMLIRMVectorDoubleMLPLIVScalar,DoubleMLPLPRScalarDID,DIDCSBinary,DIDMulti)DoubleMLVectorbase-class testspolicy_tree()port